
#' Plot the topology of a classification tree.
#'
#' Builds a data.tree Node from `obj` and renders it with DiagrammeR,
#' laid out left to right: internal nodes are drawn as orange rounded boxes
#' and leaves as green eggs.
#'
#' @param obj Either a tree table as a data.frame (each row one
#'   root-to-leaf path, each column one level of the hierarchy) or the
#'   RefMod object created by the CreateHier function.
#' @param ... Currently unused.
#' @return The object returned by plotting the data.tree Node.
PlotTopoTree <- function(obj, ...) {
  library(data.tree)
  library(DiagrammeR)

  # inherits() rather than class(obj) == "...": objects such as tibbles carry
  # several classes, which made the original `if` condition length > 1.
  if (inherits(obj, "data.frame")) {
    # "/" is the delimiter used to build pathString below, so it must not
    # appear inside labels; "+" and "-" are replaced with "." as well.
    obj <- data.frame(lapply(obj, function(x) gsub("\\+|-|/", ".", x)),
                      stringsAsFactors = FALSE)
    # Every path hangs off a common synthetic root called "TaxaRoot".
    obj$pathString <- apply(cbind("TaxaRoot", obj), 1, paste0, collapse = "/")
    taxa <- as.Node(obj)
  } else if (inherits(obj, "RefMod")) {
    taxa <- data.tree::as.Node(obj@tree[[1]])
  } else {
    stop("'obj' must be a data.frame or a RefMod object, not: ",
         paste(class(obj), collapse = ", "), call. = FALSE)
  }

  SetGraphStyle(taxa, rankdir = "LR")
  SetEdgeStyle(taxa, arrowhead = "vee", color = "grey35", penwidth = "2px")
  # Internal nodes: orange rounded boxes.
  taxa$Do(function(node) {
    SetNodeStyle(node, shape = "box", fillcolor = "chocolate2",
                 style = "filled,rounded", fontname = "helvetica",
                 fontcolor = "black", tooltip = GetDefaultTooltip, width = 3)
  }, filterFun = isNotLeaf)
  # Leaves: green eggs.
  taxa$Do(function(node) {
    SetNodeStyle(node, shape = "egg", fillcolor = "green",
                 style = "filled,rounded", fontname = "helvetica",
                 fontcolor = "black", tooltip = GetDefaultTooltip, width = 3)
  }, filterFun = isLeaf)

  pp <- plot(taxa, direction = "descend")

  return(pp)
}

#' Plot per-node classification accuracy on the tree.
#'
#' For every node model in `refMod@model`, takes the tuning-result row(s)
#' whose `mtry` equals the final model's `mtry` and shows its `Accuracy` as a
#' percentage (one decimal) on the matching tree node. Tree nodes without a
#' matching model show 0%. Fill shades from white (0%) to cyan (100%).
#' Also prints the table returned by NodesAcc().
#'
#' @param refMod the hierarchical reference model (RefMod) created by the
#'   CreateHier function.
#' @param ... Currently unused.
#' @return The object returned by plotting the data.tree Node.
PlotTopoNodeAcc <- function(refMod, ...) {
  library(data.tree)
  library(DiagrammeR)
  tree <- refMod@tree[[1]]
  taxa <- data.tree::as.Node(tree)
  SetGraphStyle(taxa, rankdir = "LR")
  SetEdgeStyle(taxa, arrowhead = "vee", color = "grey35", penwidth = "2px")

  # Extract the node accuracy metrics: one data.frame per node model, bound
  # once at the end instead of growing a result with rbind() in the loop.
  # Model names are numeric node ids; internal-node labels are indexed after
  # the tips, hence the offset by the tip count.
  nodeRows <- lapply(names(refMod@model), function(i) {
    model <- refMod@model[[i]]
    results <- model$results
    cbind(node = i,
          results[which(results$mtry == model$finalModel$mtry), ],
          NodeLabel = tree$node.label[as.numeric(i) - length(tree$tip.label)],
          classSize = length(model$levels))
  })
  nodeStats <- do.call(rbind, nodeRows)
  if (is.null(nodeStats)) {
    stop("'refMod' has no node models to report accuracies for.",
         call. = FALSE)
  }
  nodeStats <- nodeStats[which(nodeStats$NodeLabel %in% tree$node.label), ]
  rownames(nodeStats) <- nodeStats$NodeLabel
  nodeAcc <- round(nodeStats$Accuracy * 100, digits = 1)
  names(nodeAcc) <- nodeStats$NodeLabel

  # Look up each tree node's accuracy by its FixLab-normalised name;
  # nodes without an entry get 0.
  Rand <- function(x) {
    aa <- FixLab(xstring = x$name)
    nodeAcc[aa]
  }
  countPct <- as.character(taxa$Get(Rand))
  countPct[is.na(countPct)] <- "0"
  countPct <- as.numeric(countPct)
  taxa$Set(Perct = countPct)

  # 101 colours, one per whole percent. The index is the integer part of the
  # percentage, clamped to 0..100 so the fill can never be NA.
  cols <- colorRampPalette(c("white", "cyan"))(101)
  pctColour <- function(p) cols[min(max(floor(as.numeric(p)), 0), 100) + 1]
  SetNodeStyle(taxa,
               style = "filled,rounded",
               shape = "box",
               label = function(node) paste0(node$name, "\n", node$Perct, "%"),
               fillcolor = function(node) pctColour(node$Perct),
               fontname = "helvetica",
               fontcolor = "black",
               tooltip = function(node) paste(node$Perct, "%"),
               width = 3)
  pp <- plot(taxa, direction = "descend")
  print(NodesAcc(HieRMod = refMod))
  return(pp)
}

#' Plot the percentage of cells projected onto each node of the tree.
#'
#' Tabulates `HieRobj@Evaluation$Projection`, converts the counts to
#' percentages (one decimal) and shows them on the matching tree nodes.
#' Nodes with no projected cells show 0%. Fill shades from white (0%) to
#' green (100%).
#'
#' @param HieRobj the object generated by the HieRFIT function.
#' @param aggregate logical; if TRUE (default), each internal node shows its
#'   own percentage plus the cumulative percentages of its children, so the
#'   results are accumulated up the tree.
#' @param ... Currently unused.
#' @return The object returned by plotting the data.tree Node.
PlotTopoStats <- function(HieRobj, aggregate = TRUE, ...) {
  library(data.tree)
  library(DiagrammeR)
  taxa <- data.tree::as.Node(HieRobj@tree[[1]])
  SetGraphStyle(taxa, rankdir = "LR")
  SetEdgeStyle(taxa, arrowhead = "vee", color = "grey35", penwidth = "2px")

  # Percentage of cells per projected class, keyed by FixLab-normalised name.
  freq <- table(HieRobj@Evaluation$Projection)
  names(freq) <- FixLab(xstring = names(freq))
  freq <- round(freq * 100 / sum(freq), digits = 1)
  Rand <- function(x) {
    aa <- FixLab(xstring = x$name)
    freq[aa]
  }
  countPct <- as.character(taxa$Get(Rand))
  countPct[is.na(countPct)] <- "0"
  countPct <- as.numeric(countPct)
  taxa$Set(Perct = countPct)

  if (aggregate) {
    # Post-order visits children before their parent, so each child's Perct
    # is already cumulative when the parent adds Aggregate(..., sum) of them
    # (assuming Aggregate sums the children's Perct, per data.tree docs).
    taxa$Do(function(node) {
      node$Perct <- node$Perct + Aggregate(node, attribute = "Perct", aggFun = sum)
    }, filterFun = isNotLeaf, traversal = "post-order")
  }

  # 101 colours, one per whole percent. Sums of rounded percentages can
  # slightly exceed 100 after aggregation, so the index (integer part of
  # the percentage) is clamped to 0..100 to avoid an NA fill colour.
  cols <- colorRampPalette(c("white", "green"))(101)
  pctColour <- function(p) cols[min(max(floor(as.numeric(p)), 0), 100) + 1]
  SetNodeStyle(taxa,
               style = "filled,rounded",
               shape = "box",
               label = function(node) paste0(node$name, "\n", node$Perct, "%"),
               fillcolor = function(node) pctColour(node$Perct),
               fontname = "helvetica",
               fontcolor = "black",
               tooltip = function(node) paste(node$Perct, "%"),
               width = 3)
  pp <- plot(taxa, direction = "descend")

  return(pp)
}

#' Plot the projection results as a bar chart.
#'
#' Draws one bar per projected class in `HieRobj@Evaluation$Projection`,
#' with bar height equal to the number of cells assigned to that class.
#'
#' @param HieRobj hierobject generated by HieRFIT function.
#' @return A ggplot object.
PlotBarStats <- function(HieRobj) {
  # geom_bar() is ggplot2's count geom for a discrete x. It draws the same
  # bars as geom_histogram(stat = "count") without forwarding histogram-only
  # bin parameters that stat_count does not use.
  p <- ggplot(HieRobj@Evaluation) +
    geom_bar(aes(x = Projection, fill = Projection)) +
    theme(axis.text.x = element_text(angle = 90, hjust = 1),
          legend.position = "none") +
    labs(y = "Number of Cells (bars)", title = "Projection outcome")
  return(p)
}

#' Cross-comparison plot between projections and prior class labels.
#'
#' Draws an alluvial plot linking each cell's prior label to its projected
#' class (`HieRobj@Evaluation$Projection`). Flow widths are cell counts.
#'
#' @param HieRobj hierobject generated by HieRFIT function.
#' @param Prior prior class labels, one per cell, in the same order as the
#'   projections. Used only when `HieRobj@Prior` is empty; labels stored in
#'   the object take precedence.
#' @return A ggplot object.
CrossCheck <- function(HieRobj, Prior = NULL) {
  library(tidyverse)
  library(alluvial)
  library(ggalluvial)

  if (!identical(HieRobj@Prior, character(0))) {
    priorLabels <- HieRobj@Prior
  } else if (is.null(Prior)) {
    stop("Please, provide a list of Prior labels using 'Prior' argument!")
  } else {
    priorLabels <- Prior
  }
  projections <- HieRobj@Evaluation$Projection

  # data.frame() silently recycles a shorter vector whose length divides the
  # longer one, which would pair labels with the wrong cells.
  if (length(priorLabels) != length(projections)) {
    stop("There are ", length(priorLabels), " prior labels but ",
         length(projections), " projections; they must match one-to-one.",
         call. = FALSE)
  }
  PriorPostTable <- data.frame(Prior = priorLabels, Post = projections)

  # One row per (Prior, Post) pair with its cell count `n`, largest first.
  crx <- count(PriorPostTable, Prior, Post, sort = TRUE)

  p5 <- ggplot(crx, aes(y = n, axis1 = Prior, axis2 = Post)) +
    geom_alluvium(aes(fill = Post), width = 1/12) +
    geom_stratum(width = 1/12, fill = "black", color = "red") +
    geom_label(stat = "stratum", label.strata = TRUE) +
    scale_x_discrete(limits = c("PriorLabels", "Projections"),
                     expand = c(.05, .05)) +
    ggtitle("Predictions Cross-Check")

  return(p5)
}
